/*
    +--------------------------------------------------------------------+
    | libmemcached-awesome - C/C++ Client Library for memcached          |
    +--------------------------------------------------------------------+
    | Redistribution and use in source and binary forms, with or without |
    | modification, are permitted under the terms of the BSD license.    |
    | You should have received a copy of the license in a bundled file   |
    | named LICENSE; in case you did not receive a copy you can review   |
    | the terms online at: https://opensource.org/licenses/BSD-3-Clause  |
    +--------------------------------------------------------------------+
    | Copyright (c) 2006-2014 Brian Aker   https://datadifferential.com/ |
    | Copyright (c) 2020-2021 Michael Wallner        https://awesome.co/ |
    +--------------------------------------------------------------------+
*/

#include "libmemcached/common.h"

#include <cassert>
#include <atomic>

#if defined(LIBMEMCACHED_WITH_SASL_SUPPORT) && LIBMEMCACHED_WITH_SASL_SUPPORT

#  if defined(HAVE_LIBSASL) && HAVE_LIBSASL
#    include <sasl/sasl.h>
#  endif

#  define CAST_SASL_CB(cb) reinterpret_cast<int (*)()>(reinterpret_cast<intptr_t>(cb))

#  include <pthread.h>

void memcached_set_sasl_callbacks(memcached_st *shell, const sasl_callback_t *callbacks) {
  Memcached *self = memcached2Memcached(shell);
  if (self) {
    self->sasl.callbacks = const_cast<sasl_callback_t *>(callbacks);
    self->sasl.is_allocated = false;
  }
}

sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *shell) {
  Memcached *self = memcached2Memcached(shell);
  if (self) {
    return self->sasl.callbacks;
  }

  return NULL;
}

/**
 * Resolve the names for both ends of a connection
 * @param fd socket to check
 * @param laddr local address (out)
 * @param raddr remote address (out)
 * @return true on success false otherwise (errno contains more info)
 */
static memcached_return_t resolve_names(memcached_instance_st &server, char *laddr,
                                        size_t laddr_length, char *raddr, size_t raddr_length) {
  char host[MEMCACHED_NI_MAXHOST];
  char port[MEMCACHED_NI_MAXSERV];
  struct sockaddr_storage saddr;
  socklen_t salen = sizeof(saddr);

  if (getsockname(server.fd, (struct sockaddr *) &saddr, &salen) < 0) {
    return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT);
  }

  if (getnameinfo((struct sockaddr *) &saddr, salen, host, sizeof(host), port, sizeof(port),
                  NI_NUMERICHOST | NI_NUMERICSERV)
      < 0)
  {
    return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT);
  }

  (void) snprintf(laddr, laddr_length, "%s;%s", host, port);
  salen = sizeof(saddr);

  if (getpeername(server.fd, (struct sockaddr *) &saddr, &salen) < 0) {
    return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT);
  }

  if (getnameinfo((struct sockaddr *) &saddr, salen, host, sizeof(host), port, sizeof(port),
                  NI_NUMERICHOST | NI_NUMERICSERV)
      < 0)
  {
    return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT);
  }

  (void) snprintf(raddr, raddr_length, "%s;%s", host, port);

  return MEMCACHED_SUCCESS;
}

extern "C" {

static void sasl_shutdown_function() {
#if HAVE_SASL_CLIENT_DONE
  (void) sasl_client_done();
#else
  sasl_done();
#endif
}

static std::atomic<int> sasl_startup_state(SASL_OK);
static pthread_mutex_t sasl_startup_state_LOCK = PTHREAD_MUTEX_INITIALIZER;
static pthread_once_t sasl_startup_once = PTHREAD_ONCE_INIT;
static void sasl_startup_function(void) {
  sasl_startup_state = sasl_client_init(NULL);

  if (sasl_startup_state == SASL_OK) {
    (void) atexit(sasl_shutdown_function);
  }
}

} // extern "C"

memcached_return_t memcached_sasl_authenticate_connection(memcached_instance_st *server) {
  if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) {
    return MEMCACHED_NOT_SUPPORTED;
  }

  if (server == NULL) {
    return MEMCACHED_INVALID_ARGUMENTS;
  }

  /* SANITY CHECK: SASL can only be used with the binary protocol */
  if (memcached_is_binary(server->root) == false) {
    return memcached_set_error(
        *server, MEMCACHED_INVALID_ARGUMENTS, MEMCACHED_AT,
        memcached_literal_param(
            "memcached_sasl_authenticate_connection() is not supported via the ASCII protocol"));
  }

  /* Try to get the supported mech from the server. Servers without SASL
   * support will return UNKNOWN COMMAND, so we can just treat that
   * as authenticated
   */
  protocol_binary_request_no_extras request = {};

  initialize_binary_request(server, request.message.header);

  request.message.header.request.opcode = PROTOCOL_BINARY_CMD_SASL_LIST_MECHS;

  if (memcached_io_write(server, request.bytes, sizeof(request.bytes), true)
      != sizeof(request.bytes)) {
    return MEMCACHED_WRITE_FAILURE;
  }
  assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");

  memcached_server_response_increment(server);

  char mech[MEMCACHED_MAX_BUFFER] = {0};
  memcached_return_t rc = memcached_response(server, mech, sizeof(mech) - 1, NULL);
  if (memcached_failed(rc)) {
    if (rc == MEMCACHED_PROTOCOL_ERROR) {
      /* If the server doesn't support SASL it will return PROTOCOL_ERROR.
       * This error may also be returned for other errors, but let's assume
       * that the server don't support SASL and treat it as success and
       * let the client fail with the next operation if the error was
       * caused by another problem....
       */
      rc = MEMCACHED_SUCCESS;
    }

    return rc;
  }
  assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");

  /* set ip addresses */
  char laddr[MEMCACHED_NI_MAXHOST + MEMCACHED_NI_MAXSERV];
  char raddr[MEMCACHED_NI_MAXHOST + MEMCACHED_NI_MAXSERV];

  if (memcached_failed(rc = resolve_names(*server, laddr, sizeof(laddr), raddr, sizeof(raddr)))) {
    return rc;
  }

  int pthread_error;
  if ((pthread_error = pthread_once(&sasl_startup_once, sasl_startup_function))) {
    return memcached_set_errno(*server, pthread_error, MEMCACHED_AT);
  }

  (void) pthread_mutex_lock(&sasl_startup_state_LOCK);
  if (sasl_startup_state != SASL_OK) {
    const char *sasl_error_msg = sasl_errstring(sasl_startup_state, NULL, NULL);
    return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
                               memcached_string_make_from_cstr(sasl_error_msg));
  }
  (void) pthread_mutex_unlock(&sasl_startup_state_LOCK);

  sasl_conn_t *conn;
  int ret;
  if ((ret = sasl_client_new("memcached", server->_hostname, laddr, raddr,
                             server->root->sasl.callbacks, 0, &conn))
      != SASL_OK)
  {
    const char *sasl_error_msg = sasl_errstring(ret, NULL, NULL);

    sasl_dispose(&conn);

    return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
                               memcached_string_make_from_cstr(sasl_error_msg));
  }

  const char *data;
  const char *chosenmech;
  unsigned int len;
  ret = sasl_client_start(conn, mech, NULL, &data, &len, &chosenmech);
  if (ret != SASL_OK and ret != SASL_CONTINUE) {
    const char *sasl_error_msg = sasl_errstring(ret, NULL, NULL);

    sasl_dispose(&conn);

    return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
                               memcached_string_make_from_cstr(sasl_error_msg));
  }
  uint16_t keylen = (uint16_t) strlen(chosenmech);
  request.message.header.request.opcode = PROTOCOL_BINARY_CMD_SASL_AUTH;
  request.message.header.request.keylen = htons(keylen);
  request.message.header.request.bodylen = htonl(len + keylen);

  do {
    /* send the packet */

    libmemcached_io_vector_st vector[] = {
        {request.bytes, sizeof(request.bytes)}, {chosenmech, keylen}, {data, len}};

    assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");
    if (memcached_io_writev(server, vector, 3, true) == false) {
      rc = MEMCACHED_WRITE_FAILURE;
      break;
    }
    assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");
    memcached_server_response_increment(server);

    /* read the response */
    assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");
    rc = memcached_response(server, NULL, 0, NULL);
    if (rc != MEMCACHED_AUTH_CONTINUE) {
      break;
    }
    assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket");

    ret = sasl_client_step(conn, memcached_result_value(&server->root->result),
                           (unsigned int) memcached_result_length(&server->root->result), NULL,
                           &data, &len);

    if (ret != SASL_OK && ret != SASL_CONTINUE) {
      rc = MEMCACHED_AUTH_PROBLEM;
      break;
    }

    request.message.header.request.opcode = PROTOCOL_BINARY_CMD_SASL_STEP;
    request.message.header.request.bodylen = htonl(len + keylen);
  } while (true);

  /* Release resources */
  sasl_dispose(&conn);

  return memcached_set_error(*server, rc, MEMCACHED_AT);
}

static int get_username(void *context, int id, const char **result, unsigned int *len) {
  if (!context || !result || (id != SASL_CB_USER && id != SASL_CB_AUTHNAME)) {
    return SASL_BADPARAM;
  }

  *result = (char *) context;
  if (len) {
    *len = (unsigned int) strlen(*result);
  }

  return SASL_OK;
}

static int get_password(sasl_conn_t *conn, void *context, int id, sasl_secret_t **psecret) {
  if (!conn || !psecret || id != SASL_CB_PASS) {
    return SASL_BADPARAM;
  }

  *psecret = (sasl_secret_t *) context;

  return SASL_OK;
}

memcached_return_t memcached_set_sasl_auth_data(memcached_st *shell, const char *username,
                                                const char *password) {
  Memcached *ptr = memcached2Memcached(shell);
  if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) {
    return MEMCACHED_NOT_SUPPORTED;
  }

  if (ptr == NULL or username == NULL or password == NULL) {
    return MEMCACHED_INVALID_ARGUMENTS;
  }

  memcached_return_t ret;
  if (memcached_failed(ret = memcached_behavior_set(ptr, MEMCACHED_BEHAVIOR_BINARY_PROTOCOL, 1))) {
    return memcached_set_error(
        *ptr, ret, MEMCACHED_AT,
        memcached_literal_param("Unable change to binary protocol which is required for SASL."));
  }

  memcached_destroy_sasl_auth_data(ptr);

  sasl_callback_t *callbacks = libmemcached_xcalloc(ptr, 4, sasl_callback_t);
  size_t password_length = strlen(password);
  size_t username_length = strlen(username);
  char *name = (char *) libmemcached_malloc(ptr, username_length + 1);
  sasl_secret_t *secret =
      (sasl_secret_t *) libmemcached_malloc(ptr, password_length + 1 + sizeof(sasl_secret_t));

  if (callbacks == NULL or name == NULL or secret == NULL) {
    libmemcached_free(ptr, callbacks);
    libmemcached_free(ptr, name);
    libmemcached_free(ptr, secret);
    return memcached_set_error(*ptr, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT);
  }

  secret->len = password_length;
  memcpy(secret->data, password, password_length);
  secret->data[password_length] = 0;

  callbacks[0].id = SASL_CB_USER;
  callbacks[0].proc = CAST_SASL_CB(get_username);
  callbacks[0].context = strncpy(name, username, username_length + 1);
  callbacks[1].id = SASL_CB_AUTHNAME;
  callbacks[1].proc = CAST_SASL_CB(get_username);
  callbacks[1].context = name;
  callbacks[2].id = SASL_CB_PASS;
  callbacks[2].proc = CAST_SASL_CB(get_password);
  callbacks[2].context = secret;
  callbacks[3].id = SASL_CB_LIST_END;

  ptr->sasl.callbacks = callbacks;
  ptr->sasl.is_allocated = true;

  return MEMCACHED_SUCCESS;
}

memcached_return_t memcached_destroy_sasl_auth_data(memcached_st *shell) {
  if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) {
    return MEMCACHED_NOT_SUPPORTED;
  }

  Memcached *ptr = memcached2Memcached(shell);
  if (ptr == NULL) {
    return MEMCACHED_INVALID_ARGUMENTS;
  }

  if (ptr->sasl.callbacks == NULL) {
    return MEMCACHED_SUCCESS;
  }

  if (ptr->sasl.is_allocated) {
    libmemcached_free(ptr, ptr->sasl.callbacks[0].context);
    libmemcached_free(ptr, ptr->sasl.callbacks[2].context);
    libmemcached_free(ptr, (void *) ptr->sasl.callbacks);
    ptr->sasl.is_allocated = false;
  }

  ptr->sasl.callbacks = NULL;

  return MEMCACHED_SUCCESS;
}

memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st *source) {
  if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) {
    return MEMCACHED_NOT_SUPPORTED;
  }

  if (clone == NULL or source == NULL) {
    return MEMCACHED_INVALID_ARGUMENTS;
  }

  if (source->sasl.callbacks == NULL) {
    return MEMCACHED_SUCCESS;
  }

  /* Hopefully we are using our own callback mechanisms.. */
  if (source->sasl.callbacks[0].id == SASL_CB_USER
      && source->sasl.callbacks[0].proc == CAST_SASL_CB(get_username)
      && source->sasl.callbacks[1].id == SASL_CB_AUTHNAME
      && source->sasl.callbacks[1].proc == CAST_SASL_CB(get_username)
      && source->sasl.callbacks[2].id == SASL_CB_PASS
      && source->sasl.callbacks[2].proc == CAST_SASL_CB(get_password)
      && source->sasl.callbacks[3].id == SASL_CB_LIST_END)
  {
    sasl_secret_t *secret = (sasl_secret_t *) source->sasl.callbacks[2].context;
    return memcached_set_sasl_auth_data(clone, (const char *) source->sasl.callbacks[0].context,
                                        (const char *) secret->data);
  }

  /*
   * But we're not. It may work if we know what the user tries to pass
   * into the list, but if we don't know the ID we don't know how to handle
   * the context...
   */
  ptrdiff_t total = 0;

  while (source->sasl.callbacks[total].id != SASL_CB_LIST_END) {
    switch (source->sasl.callbacks[total].id) {
    case SASL_CB_USER:
    case SASL_CB_AUTHNAME:
    case SASL_CB_PASS:
      break;
    default:
      /* I don't know how to deal with this... */
      return MEMCACHED_NOT_SUPPORTED;
    }

    ++total;
  }

  sasl_callback_t *callbacks = libmemcached_xcalloc(clone, total + 1, sasl_callback_t);
  if (callbacks == NULL) {
    return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
  }
  memcpy(callbacks, source->sasl.callbacks, (total + 1) * sizeof(sasl_callback_t));

  /* Now update the context... */
  for (ptrdiff_t x = 0; x < total; ++x) {
    if (callbacks[x].id == SASL_CB_USER || callbacks[x].id == SASL_CB_AUTHNAME) {
      callbacks[x].context = (sasl_callback_t *) libmemcached_malloc(
          clone, strlen((const char *) source->sasl.callbacks[x].context));

      if (callbacks[x].context == NULL) {
        /* Failed to allocate memory, clean up previously allocated memory */
        for (ptrdiff_t y = 0; y < x; ++y) {
          libmemcached_free(clone, clone->sasl.callbacks[y].context);
        }

        libmemcached_free(clone, callbacks);
        return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
      }
      strncpy((char *) callbacks[x].context, (const char *) source->sasl.callbacks[x].context,
              sizeof(callbacks[x].context));
    } else {
      sasl_secret_t *src = (sasl_secret_t *) source->sasl.callbacks[x].context;
      sasl_secret_t *n = (sasl_secret_t *) libmemcached_malloc(clone, src->len + 1 + sizeof(*n));
      if (n == NULL) {
        /* Failed to allocate memory, clean up previously allocated memory */
        for (ptrdiff_t y = 0; y < x; ++y) {
          libmemcached_free(clone, clone->sasl.callbacks[y].context);
        }

        libmemcached_free(clone, callbacks);
        return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
      }
      memcpy(n, src, src->len + 1 + sizeof(*n));
      callbacks[x].context = n;
    }
  }

  clone->sasl.callbacks = callbacks;
  clone->sasl.is_allocated = true;

  return MEMCACHED_SUCCESS;
}

#else

void memcached_set_sasl_callbacks(memcached_st *, const sasl_callback_t *) {}

sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *) {
  return NULL;
}

memcached_return_t memcached_set_sasl_auth_data(memcached_st *, const char *, const char *) {
  return MEMCACHED_NOT_SUPPORTED;
}

memcached_return_t memcached_clone_sasl(memcached_st *, const memcached_st *) {
  return MEMCACHED_NOT_SUPPORTED;
}

memcached_return_t memcached_destroy_sasl_auth_data(memcached_st *) {
  return MEMCACHED_NOT_SUPPORTED;
}

memcached_return_t memcached_sasl_authenticate_connection(memcached_instance_st *) {
  return MEMCACHED_NOT_SUPPORTED;
}

#endif
